"""Implements a continued fraction class."""

import basic_representation_types
# from decimal import Decimal as dec
from mpmath import mpf as dec, isnormal, isnan, mp
from gen_consts import gen_pi_const
import math
import scipy.stats # used for the linear regression fitting
from utils import MathOperations

class ZeroB(Exception):
    """Raised when a zero numerator b_i (i > 0) truncates a ContFrac, i.e. the expansion turned out to be finite.

    The iteration index at which the zero numerator appeared is passed as the exception's single argument."""


class ContFrac(basic_representation_types.BasicContFracAlgo):
    """Generates a continued fraction whose numerator and denominator are generated by polynomials.
    The i'th iteration (a_0 is the 0'th iteration) is a_0 + b_1 / (a_1 + b_2 / (... + b_i / a_i)).
    If b_i = 0 for some i > 0 the expansion is finite (a rational number) and such results should be skipped;
    this is what avoid_zero_b is for."""

    def __init__(self, a_coeffs, b_coeffs, avoid_zero_b=True, iter_promo_matrix_generator=None,
                 first_iter_matrix=None, target_val=None, logging=False):
        """a_coeffs - [[x_1, y_1, z_1, ...], [x_2, y_2, z_2, ...], ..., [x_{n-1}, y_{n-1}, z_{n-1}, ...]] <=>
                      a_i = x_(i%n)+y_(i%n)*i+z(i%n)*i^2+...
                      if no interlace is needed, [x_1, y_1, z_1, ...] can be supplied as well.
        b_coeffs - same as a_coeffs, for b_i = ...
        avoid_zero_b - if a finite continued fraction is achieved (b_i=0 for some i>0), the run stops and the
                       exception ZeroB is raised.
        iter_promo_matrix_generator - a method that generates a promotion matrix. Default: self._default_promo_mat_gen
                                     See help for self._default_promo_mat_gen for more information.
        first_iter_matrix - the initial state ((P_0, P_-1), (Q_0, Q_-1)). Default: ((a_0, 1), (1, 0)), generated by
                            self._autogen_first_iter_matrix (and regenerated by reinitialize).
        target_val - passed to the base class (presumably the constant being approximated). Default: gen_pi_const().
        logging - passed to the base class; iteration_algorithm reads it back as self._logging and, when set,
                  stores p_i/q_i in params['contfrac_res'] on every iteration."""
        if iter_promo_matrix_generator is None:
            iter_promo_matrix_generator = self._default_promo_mat_gen
        if target_val is None:
            target_val = gen_pi_const()

        a_coeffs = self._as_interlaced(a_coeffs)
        b_coeffs = self._as_interlaced(b_coeffs)

        params = {'a_coeffs': a_coeffs, 'b_coeffs': b_coeffs, 'contfrac_res': 1}
        if first_iter_matrix is None:
            # Remember that the first matrix was auto-generated, so reinitialize can regenerate it for new coeffs.
            self._auto_first_promo_mat = True
            first_iter_matrix = self._autogen_first_iter_matrix(params)
        else:
            self._auto_first_promo_mat = False

        self._avoid_zero_b = avoid_zero_b
        self._approach_type, self._approach_params = 'undefined', 0
        super().__init__(params, iter_promo_matrix_generator, first_iter_matrix, target_val, logging)

    @staticmethod
    def _as_interlaced(coeffs):
        """Wraps a flat coefficient list [x, y, z, ...] as [[x, y, z, ...]]; an interlaced list is returned as is."""
        if not isinstance(coeffs[0], (list, tuple)):
            return [coeffs]
        return coeffs

    def reinitialize(self, a_coeffs=None, b_coeffs=None, first_iter_matrix=None, **kwargs):
        """Reinitialize the instance with new parameters. Useful mostly to avoid intensive instances creation.
        a_coeffs, b_coeffs - new coefficients, in the same formats __init__ accepts (flat or interlaced). When not
                             given (or empty), the current ones are kept.
        first_iter_matrix - new initial state. If not given and the original one was auto-generated, it is
                            regenerated from the new a_coeffs.
        kwargs - extra entries for the params dict passed to the base class."""
        self._approach_type, self._approach_params = 'undefined', 0
        params = kwargs
        # Normalize flat lists exactly like __init__ does; _autogen_first_iter_matrix and _default_promo_mat_gen
        # both index the coefficients as a list of polynomials.
        params['a_coeffs'] = self._as_interlaced(a_coeffs) if a_coeffs else self._params['a_coeffs']
        params['b_coeffs'] = self._as_interlaced(b_coeffs) if b_coeffs else self._params['b_coeffs']
        if first_iter_matrix is None and self._auto_first_promo_mat:
            first_iter_matrix = self._autogen_first_iter_matrix(params)

        super().reinitialize(params, first_iter_matrix)

    def iteration_algorithm(self, params, state_mat, i, print_calc):
        """Generates the next iteration, using the "matrix product" of params and promo_matrix of:
        | P_{i+1} Q_{i+1} |   | a_i b_i |   | P_i     Q_i     |   | P_0    Q_0    |   | a_0 1 |
        | P_i     Q_i     | = | 1   0   | X | P_{i-1} Q_{i-1} | , | P_{-1} Q_{-1} | = | 1   0 |
        contfrac_i = P_i/Q_i

        params - {'a_coeffs': ..., 'b_coeffs': ..., 'contfrac_res': ...}
        state_mat - ((p_i, p_{i-1}), (q_i, q_{i-1})), is actually the current state matrix. Name is only for
                    signature compatibility.
        i - iteration index, used to generate the [a_i b_i; 1 0] matrix.
        print_calc - whether to print the calculation process. For debugging.
        Returns (params, new_state_mat).
        Raises ValueError if a new p or q is (positive or negative) infinity."""
        i += 1
        a_coeffs = params['a_coeffs']
        b_coeffs = params['b_coeffs']
        # ab_mat - the [a_i b_i; 1 0] matrix. Only its first row matters, so it's just an (a_i, b_i) pair.
        ab_mat = self._iter_promo_matrix_generator(i, a_coeffs, b_coeffs, print_calc)
        p_vec, q_vec = state_mat

        # TODO: if calculations are slow, once in a while (i % 100 == 0?) check if gcd(p_i,p_{i-1}) != 0
        # TODO: and take it out as a factor (add p_gcd, q_gcd params)
        new_p_vec = (sum([l*m for l, m in zip(p_vec, ab_mat)]), p_vec[0])
        new_q_vec = (sum([l*m for l, m in zip(q_vec, ab_mat)]), q_vec[0])
        # Infinity can only appear with float/mpf arithmetic. p_i and q_i may be negative, so check both signs.
        infinities = (dec('inf'), dec('-inf'))
        if any(v in infinities for v in new_p_vec):
            raise ValueError('infinity p')
        if any(v in infinities for v in new_q_vec):
            raise ValueError('infinity q')

        if self._logging:
            p_i = dec(new_p_vec[0])
            q_i = dec(new_q_vec[0])
            # isnormal is False for zero, infinity and NaN, so a degenerate q_i yields NaN instead of an exception.
            if not isnormal(q_i):
                contfrac_res_i = dec('NaN')
            else:
                contfrac_res_i = p_i / q_i
            params['contfrac_res'] = contfrac_res_i
        new_promo_mat = (new_p_vec, new_q_vec)

        return (params, new_promo_mat)

    def finalize_iterations(self, params, promo_mat, n):
        """Finalize the iteration by computing contfrac_res = p/q.
        p, q are taken from promo_mat (see self.iteration_algorithm for more info).
        contfrac_res is saved to params (NaN if q is zero, infinite, NaN or subnormal).
        n is unused (and is here only to keep method signature).
        Returns (params, promo_mat)."""
        p_vec, q_vec = promo_mat
        p = dec(p_vec[0])
        q = dec(q_vec[0])
        if not isnormal(q):
            contfrac_res = dec('NaN')
        else:
            contfrac_res = p / q
        params['contfrac_res'] = contfrac_res
        return (params, promo_mat)

    def _default_promo_mat_gen(self, i, a_coeffs, b_coeffs, print_calc):
        """Default promotion-matrix generator: evaluates a_i and b_i for iteration i.
        i - iteration index. Polynomial number i % len(coeffs) is evaluated at i (for both a and b).
        print_calc - whether to print the calculation process. For debugging.
        Raises ZeroB(i) if avoid_zero_b is set and b_i == 0 for i > 0.
        Returns (a_i, b_i)"""
        a2p = MathOperations.subs_in_polynom
        a_i = a2p(a_coeffs[i % len(a_coeffs)], i)
        b_i = a2p(b_coeffs[i % len(b_coeffs)], i)
        if self._avoid_zero_b and b_i == 0 and i > 0:
            raise ZeroB(i)
        if print_calc:
            print('(%d/%d+)' % (b_i, a_i), end='')
        promo_mat = (a_i, b_i)
        return promo_mat

    def _autogen_first_iter_matrix(self, params):
        """Returns (a_0, 1), (1, 0) as P_0, P_-1, Q_0, Q_-1"""
        a_coeffs = params['a_coeffs']
        return (a_coeffs[0][0], 1), (1, 0)

    def _print_calc_0_depth(self):
        """Prints the first stage of the calculation: 'a_0+'"""
        print('%d+' % self._params['a_coeffs'][0][0], end='')

    def get_result(self):
        """Returns the continued fraction evaluated value (so far)."""
        return self._params_log['contfrac_res'][-1]

    def get_p_q(self):
        """Returns the latest numerator and denominator of the representation of the partial contfrac = p/q"""
        # iter_matrices is kept by the base class; assumed to hold the ((p_i, p_{i-1}), (q_i, q_{i-1})) states
        # produced by iteration_algorithm - TODO confirm. Index 0 is the latest value (index 1 is the previous one).
        p_vec, q_vec = self.iter_matrices[-1]
        return p_vec[0], q_vec[0]

    def is_result_valid(self):
        """Makes sure that the result isn't zero, infinity, NaN or subnormal."""
        return isnormal(self._params_log['contfrac_res'][-1])

    def estimate_approach_type_and_params(self, iters=600, initial_cutoff=200, iters_step=50):
        """Estimates how the continued fraction converges and stores the result (see get_approach_type_and_params).
        Returns None.
        Runs self._estimate_approach_type_and_params_inner_alg with find_poly_parameter=True. While the result is
        'fast' (the last sampled differences are below the working precision), the sampled window is shrunk and the
        estimate repeated: first by halving initial_cutoff down to 0, then (with iters_step=1) by reducing iters by 5
        while it is above 10."""
        def estimate(iters, initial_cutoff, iters_step):
            return self._estimate_approach_type_and_params_inner_alg(find_poly_parameter=True,
                                                                      iters=iters,
                                                                      initial_cutoff=initial_cutoff,
                                                                      iters_step=iters_step)

        approach_type, approach_params = estimate(iters, initial_cutoff, iters_step)
        while approach_type == 'fast' and initial_cutoff > 0:
            initial_cutoff >>= 1
            iters = max(50, iters >> 1)
            iters_step = max(int((iters - initial_cutoff) / 30), iters_step >> 1)
            approach_type, approach_params = estimate(iters, initial_cutoff, iters_step)
        if initial_cutoff == 0:
            iters = 50
            iters_step = 1
            while approach_type == 'fast' and iters > 10:
                iters -= 5
                approach_type, approach_params = estimate(iters, initial_cutoff, iters_step)
        self._approach_type = approach_type
        self._approach_params = approach_params

    def get_approach_type_and_params(self):
        """Returns the (approach_type, approach_params) tuple last estimated or set."""
        return (self._approach_type, self._approach_params)

    def set_approach_type_and_params(self, convergence_info):
        """Sets the (approach_type, approach_params) tuple, e.g. from a previously computed estimation."""
        self._approach_type, self._approach_params = convergence_info

    def is_convergence_fast(self, find_poly_parameter=False, iters=600, initial_cutoff=200, iters_step=50,
                            exponential_threshold=1.1):
        """Returns True if the convergence type is exponential or over exponential.
        False if it's sub exponential (e.g. linear).
        After initial_cutoff iterations, and then every iters_step iterations up to iters, 6 consecutive (p, q) pairs
        are sampled and an integer criterion is checked (see the comment in the loop). gen_iterations(0) is called
        before returning.
        find_poly_parameter and exponential_threshold are unused; they are kept for signature compatibility."""
        # This is an old function. It's still in use, but it's yucky and I don't feel like going through it and
        # neither should you :(
        # Hopefully it will be replaced with some more theoretical insights.
        # HOWEVER, IT IS GENERAL AND MAY BE USEFUL FOR NEW, FUTURE REPRESENTATION METHODS.
        # NOTE: iters_step should be at least 6 (5 iterations are consumed per window, see add_iterations below);
        # this is deliberately not enforced.

        self.gen_iterations(initial_cutoff, exec_finalize=False)

        return_val = True
        for _window in range(initial_cutoff + iters_step, iters + 1, iters_step):
            samples = [self.get_p_q()]
            for _ in range(5):
                self.add_iterations(1, exec_finalize=False)
                samples.append(self.get_p_q())
            (p_0, q_0), (p_1, q_1), (p_2, q_2), (p_3, q_3), (p_4, q_4), (p_5, q_5) = samples

            # In the case of a (super-)exponential convergence, the ratio of two subsequent iteration should approach to
            # a constant. This is a derivation of this fact to work with integers. Further details about this claim can
            # be found in the "convergence_analysis.pdf" file.
            # q_4(p_2q_0-p_0q_2) > q_0(p_4q_2-p_2q_4)
            lhs_pair = abs(q_4 * (p_2 * q_0 - p_0 * q_2))
            rhs_pair = abs(q_0 * (p_4 * q_2 - p_2 * q_4))
            lhs_odd = abs(q_5 * (p_3 * q_1 - p_1 * q_3))
            rhs_odd = abs(q_1 * (p_5 * q_3 - p_3 * q_5))
            if (((lhs_pair - rhs_pair) <= (rhs_pair >> 4)) or
                    ((lhs_odd - rhs_odd) <= (rhs_odd >> 4))):
                return_val = False
                break
            # -5 for the 5 iterations already executed while sampling this window
            self.add_iterations(iters_step - 5, exec_finalize=False)

        self.gen_iterations(0)
        return return_val

    @staticmethod
    def _loglog_linregress(deltas):
        """Linear regression of log(delta) against log(iteration) over (iteration, delta) samples.
        Samples without a logarithm (iteration 0, zero or NaN delta) are skipped. Returns scipy's linregress result,
        or None if fewer than 2 usable samples remain."""
        points = [(i, d) for i, d in deltas if i > 0 and d > 0]  # NaN fails 'd > 0' and is dropped as well
        if len(points) < 2:
            return None
        # mp.log handles deltas below float's range (float(d) would underflow to 0 and break math.log).
        return scipy.stats.linregress([math.log(i) for i, _ in points], [float(mp.log(d)) for _, d in points])

    def _estimate_approach_type_and_params_inner_alg(self, find_poly_parameter=False, iters=5000, initial_cutoff=1500,
                                                     iters_step=500, exponential_threshold=1.1):
        """Categorizes the convergence. Returns a tuple (approach_type, approach_parameter) where approach_type is
        one of None, 'exp', 'super_exp', 'poly2sympoly', 'undefined', 'fast' and 'mixed', and approach_parameter is:
        - (ratio, coefficient) for 'exp' / 'super_exp',
        - (slope parameter, intercept, R**2) of the log-log fit for 'poly2sympoly' when find_poly_parameter is set,
        - 0 otherwise."""
        # This is an old function. It's still in use, but it's yucky and I don't feel like going through it and
        # document it. Sorry :(
        # Hopefully it will be replaced with some theoretical insights.
        # HOWEVER, IT IS GENERAL AND MAY BE USEFUL FOR NEW, FUTURE REPRESENTATION METHODS.
        # See the "convergence_analysis.pdf" file for further details about theory behind this function.
        # It evaluates and categorizes by computing linear fitting to the error/iterations plot. Regular or logarithmic.

        effective_zero = dec(10)**-mp.dps

        # NOTE: iters_step should be at least 6 (5 iterations are consumed per window), but this is deliberately not
        # enforced: estimate_approach_type_and_params calls this method with iters_step=1.

        approach_type = None
        approach_parameter = 0

        delta_pair = []
        delta_odd = []

        def sample_window(start):
            """Reads the current result and the 5 following ones (one iteration apart) and records the
            differences between results two iterations apart, split into even and odd offsets from start."""
            res = [self.get_result()]
            for _ in range(5):
                self.add_iterations(1)
                res.append(self.get_result())
            delta_pair.append((start, abs(res[2] - res[0])))
            delta_pair.append((start + 2, abs(res[4] - res[2])))
            delta_odd.append((start + 1, abs(res[3] - res[1])))
            delta_odd.append((start + 3, abs(res[5] - res[3])))

        self.gen_iterations(initial_cutoff)
        sample_window(initial_cutoff)
        for i in range(initial_cutoff + iters_step, iters + 1, iters_step):
            # -5 for the 5 iterations already executed while sampling the previous window
            self.add_iterations(iters_step - 5)
            sample_window(i)

        pair_diminish = len(delta_pair) > 3 and all(abs(d) < effective_zero for _, d in delta_pair[-3:])
        odd_diminish = len(delta_odd) > 3 and all(abs(d) < effective_zero for _, d in delta_odd[-3:])
        # if one diminishes and the other isn't, return 'undefined'
        if pair_diminish ^ odd_diminish:
            approach_type = 'undefined'
        elif pair_diminish and odd_diminish:
            approach_type = 'fast'

        def window_ratios(deltas):
            """Ratio of the two differences recorded per window, skipping windows with zero or NaN differences."""
            return [(deltas[k][0], deltas[k][1] / deltas[k + 1][1])
                    for k in range(0, len(deltas), 2)
                    if deltas[k][1] != 0 and deltas[k + 1][1] != 0 and
                    not isnan(deltas[k][1]) and not isnan(deltas[k + 1][1])]

        pair_ratio = window_ratios(delta_pair)
        odd_ratio = window_ratios(delta_odd)

        # Too few usable samples for a meaningful estimate. odd_ratio must be checked too: zero/NaN filtering can
        # leave it shorter than pair_ratio (even empty, which would divide by zero below).
        if len(pair_ratio) < 6 or len(odd_ratio) < 6:
            return (approach_type, approach_parameter)

        mean_pair_ratio = sum([r for _, r in pair_ratio]) / len(pair_ratio)
        mean_pair_ratio_avg_square_error = (sum([(r - mean_pair_ratio)**2 for _, r in pair_ratio]).sqrt() /
                                            len(pair_ratio))
        mean_odd_ratio = sum([r for _, r in odd_ratio]) / len(odd_ratio)
        mean_odd_ratio_avg_square_error = (sum([(r - mean_odd_ratio)**2 for _, r in odd_ratio]).sqrt() /
                                           len(odd_ratio))
        relative_pair_sq_err = mean_pair_ratio_avg_square_error / mean_pair_ratio
        relative_odd_sq_err = mean_odd_ratio_avg_square_error / mean_odd_ratio
        if relative_pair_sq_err > 0.5 or relative_odd_sq_err > 0.5:
            # Ratios that vary a lot but keep growing (all > 2 in the last quarter) indicate super-exponential
            # convergence; otherwise the behavior can't be categorized.
            if all([r[1] > 2 for r in (pair_ratio[3*int(len(pair_ratio)/4):] +
                                       odd_ratio[3*int(len(odd_ratio)/4):])]):
                approach_type = 'super_exp'
            else:
                approach_type = 'undefined'
        if (relative_odd_sq_err <= 0.5 and relative_pair_sq_err <= 0.5) or approach_type == 'super_exp':
            is_pair_exp = mean_pair_ratio > exponential_threshold
            is_odd_exp = mean_odd_ratio > exponential_threshold
            # in case one is exponential and the other isn't return 'mixed'
            if is_pair_exp ^ is_odd_exp:
                approach_type = 'mixed'
            elif is_pair_exp and is_odd_exp:
                if approach_type != 'super_exp':
                    approach_type = 'exp'
                # Each ratio spans 2 iterations, so its square root is the per-iteration ratio.
                approach_parameter_pair = mean_pair_ratio**type(mean_pair_ratio)(0.5)
                approach_parameter_odd = mean_odd_ratio**type(mean_odd_ratio)(0.5)
                approach_parameter = min(approach_parameter_pair, approach_parameter_odd)
                approach_coeff_pair_list = [abs(d * approach_parameter**n / (1 - approach_parameter**(-2)))
                                            for n, d in delta_pair if d != 0 and not isnan(d)]
                approach_coeff_pair = sum(approach_coeff_pair_list) / len(approach_coeff_pair_list)
                approach_coeff_odd_list = [abs(d * approach_parameter**n / (1 - approach_parameter**(-2)))
                                           for n, d in delta_odd if d != 0 and not isnan(d)]
                approach_coeff_odd = sum(approach_coeff_odd_list) / len(approach_coeff_odd_list)
                approach_coeff = min(approach_coeff_pair, approach_coeff_odd)
                approach_parameter = (approach_parameter, approach_coeff)
            else:
                approach_type = 'poly2sympoly'

        if approach_type != 'poly2sympoly' or not find_poly_parameter:
            return (approach_type, approach_parameter)

        # We're requested to find the poly2sympoly parameter: fit log(delta) against log(iteration).
        fit_pair = self._loglog_linregress(delta_pair)
        fit_odd = self._loglog_linregress(delta_odd)
        if fit_pair is None or fit_odd is None:
            return (approach_type, approach_parameter)
        slope_pair, intercept_pair, r_value_pair, _, _ = fit_pair
        slope_odd, intercept_odd, r_value_odd, _, _ = fit_odd

        # TODO: replace the -1 by *0.95? need to make sure first what is the slop (is it negative?)
        approach_parameter = (min(abs(slope_pair), abs(slope_odd))-1, min(intercept_pair, intercept_odd),
                              min(r_value_pair**2, r_value_odd**2))
        return (approach_type, approach_parameter)


def eval_contfrac(a, b=None):
    """Evaluates a finite continued fraction.

    a - a list of denominator values, optionally preceded by an additive constant (see below).
    b - a list of numerator values. If omitted, 1's are used and a[0] is the additive constant.
    If len(a) == len(b) + 1, a[0] is the additive constant and the result is
        a[0]+b[0]/(a[1]+b[1]/(...+b[-1]/a[-1]))
    If len(a) == len(b), the additive constant is 0 (all of a are denominators) and the result is
        b[0]/(a[0]+b[1]/(a[1]+...+b[-1]/a[-1]))
    Raises ValueError if a is empty or if the lengths fit neither case."""
    if len(a) == 0:
        raise ValueError('a must not be empty')
    if b is None:
        b = [1] * (len(a) - 1)
    if len(a) == len(b):
        # No additive constant: b[0] is the outermost numerator, the rest pairs up with a as usual.
        leading_num, b = b[0], b[1:]
    elif len(a) == len(b) + 1:
        leading_num = None
    else:
        raise ValueError('different lengths for a and b')

    # Evaluate bottom-up, starting from the innermost denominator.
    res = a[-1]
    for a_val, b_val in zip(reversed(a[:-1]), reversed(b)):
        res = a_val + b_val / res
    return res if leading_num is None else leading_num / res


def eval_contfrac_by_polys(a, b, calculation_depth):
    """Evaluates a polynomial continued fraction truncated to a given depth.

    a, b - polynomial coefficients, in the same format ContFrac accepts: a single coefficient list [x, y, z, ...]
           (meaning x + y*i + z*i^2 + ...) or a list of such lists for interlaced polynomials, where polynomial
           number i % len is used at index i.
    calculation_depth - number of a_i terms used (a_0 ... a_{depth-1}); b_1 ... b_{depth-1} are the numerators.
                        Should be at least 1.
    Returns a_0 + b_1 / (a_1 + b_2 / (... + b_{depth-1} / a_{depth-1}))."""
    # Wrap a flat coefficient list as a single-polynomial interlace, like ContFrac does. Testing for list/tuple
    # (instead of specific scalar types) also handles float, str, Fraction, ... coefficients.
    if not isinstance(a[0], (list, tuple)):
        a = [a]
    if not isinstance(b[0], (list, tuple)):
        b = [b]

    a_vals = [MathOperations.subs_in_polynom(a[i % len(a)], i) for i in range(calculation_depth)]
    b_vals = [MathOperations.subs_in_polynom(b[i % len(b)], i) for i in range(1, calculation_depth)]
    return eval_contfrac(a_vals, b_vals)


def eval_dec_contfrac_by_polys(a, b, calculation_depth):
    """Evaluates a polynomial continued fraction truncated to a given depth, in arbitrary precision.

    a, b - polynomial coefficients, in the same format ContFrac accepts: a single coefficient list [x, y, z, ...]
           (meaning x + y*i + z*i^2 + ...) or a list of such lists for interlaced polynomials, where polynomial
           number i % len is used at index i.
    calculation_depth - number of a_i terms used (a_0 ... a_{depth-1}); b_1 ... b_{depth-1} are the numerators.
                        Should be at least 1.
    All coefficients are converted to dec (mpmath's mpf, per this module's imports) before evaluating, so set
    mpmath's working precision (mp.dps) as needed. Pass ints or strings for exact coefficients; floats carry their
    binary rounding error.
    Returns a_0 + b_1 / (a_1 + b_2 / (... + b_{depth-1} / a_{depth-1}))."""
    # Wrap a flat coefficient list as a single-polynomial interlace, like ContFrac does. Testing for list/tuple
    # (instead of specific scalar types) also handles float and string coefficients, which dec() accepts.
    if not isinstance(a[0], (list, tuple)):
        a = [a]
    if not isinstance(b[0], (list, tuple)):
        b = [b]

    a = [[dec(c) for c in poly] for poly in a]
    b = [[dec(c) for c in poly] for poly in b]
    a_vals = [MathOperations.subs_in_polynom(a[i % len(a)], i) for i in range(calculation_depth)]
    b_vals = [MathOperations.subs_in_polynom(b[i % len(b)], i) for i in range(1, calculation_depth)]
    return eval_contfrac(a_vals, b_vals)